{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BO7MEGbb6mtB"
      },
      "source": [
        "# Generate text with RuGPTs in huggingface\n",
        "How to generate text with pretrained RuGPTs models with huggingface.\n",
        "\n",
        "This notebook is valid for all RuGPTs models except RuGPT3XL.\n",
        "## Install env"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H73-Pizb6c8n"
      },
      "outputs": [],
      "source": [
        "!pip3 install transformers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvgntLymArg3"
      },
      "source": [
        "## Generate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "csHcDJXFDdaW"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TJxPg-cJDhAB"
      },
      "outputs": [],
      "source": [
        "np.random.seed(42)\n",
        "torch.manual_seed(42)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "AkUrzKsy_16F"
      },
      "outputs": [],
      "source": [
        "from transformers import GPT2LMHeadModel, GPT2Tokenizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "tV7tt-t2FQc3"
      },
      "outputs": [],
      "source": [
        "def load_tokenizer_and_model(model_name_or_path):\n",
        "  return GPT2Tokenizer.from_pretrained(model_name_or_path), GPT2LMHeadModel.from_pretrained(model_name_or_path).cuda()\n",
        "\n",
        "\n",
        "def generate(\n",
        "    model, tok, text,\n",
        "    do_sample=True, max_length=50, repetition_penalty=5.0,\n",
        "    top_k=5, top_p=0.95, temperature=1,\n",
        "    num_beams=None,\n",
        "    no_repeat_ngram_size=3\n",
        "    ):\n",
        "  input_ids = tok.encode(text, return_tensors=\"pt\").cuda()\n",
        "  out = model.generate(\n",
        "      input_ids.cuda(),\n",
        "      max_length=max_length,\n",
        "      repetition_penalty=repetition_penalty,\n",
        "      do_sample=do_sample,\n",
        "      top_k=top_k, top_p=top_p, temperature=temperature,\n",
        "      num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size\n",
        "      )\n",
        "  return list(map(tok.decode, out))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7sPySei8FO_r"
      },
      "source": [
        "### RuGPT2Large"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_EMbgO0BTvb"
      },
      "outputs": [],
      "source": [
        "tok, model = load_tokenizer_and_model(\"sberbank-ai/rugpt2large\")\n",
        "generated = generate(model, tok, \"Александр Сергеевич Пушкин родился в \", num_beams=10)\n",
        "generated[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F4X-d7fIIZFC"
      },
      "source": [
        "### RuGPT3Small"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "24oUrAfBIk6G"
      },
      "outputs": [],
      "source": [
        "tok, model = load_tokenizer_and_model(\"sberbank-ai/rugpt3small_based_on_gpt2\")\n",
        "generated = generate(model, tok, \"Александр Сергеевич Пушкин родился в \", num_beams=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "SGTZin-JIu_N",
        "outputId": "52795a45-12ef-47f8-e7f9-84f10077f986"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Александр Сергеевич Пушкин родился в  1825 г. в семье поэта Александра Сергеевича Пушкина и его жены Александры Николаевны Пушкиной (урожденной Пушкиных). В 1783 г. он поступил на юридический факультет Санкт-Петербургского университета'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "generated[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GHrO9tovIyyj"
      },
      "source": [
        "### RuGPT3Medium"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2MVyT8zAIyys"
      },
      "outputs": [],
      "source": [
        "tok, model = load_tokenizer_and_model(\"sberbank-ai/rugpt3medium_based_on_gpt2\")\n",
        "generated = generate(model, tok, \"Александр Сергеевич Пушкин родился в \", num_beams=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "W3SWmttlJHF7",
        "outputId": "d4e97e47-3ac0-4072-f9b2-bc0aca2b802c"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Александр Сергеевич Пушкин родился в  1799 году, умер в 1837-м. Он был одним из самых образованных и одаренных людей своего времени. У него было много увлечений: он увлекался математикой, физикой, астрономией,'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 11
        }
      ],
      "source": [
        "generated[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HnU-9k3dIzfy"
      },
      "source": [
        "### RuGPT3Large"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z14U66yuIzfz"
      },
      "outputs": [],
      "source": [
        "tok, model = load_tokenizer_and_model(\"sberbank-ai/rugpt3large_based_on_gpt2\")\n",
        "generated = generate(model, tok, \"Александр Сергеевич Пушкин родился в \", num_beams=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "VFuy-V2xJmwu",
        "outputId": "c50acf5d-df76-4b06-a325-14a7148a24ee"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Александр Сергеевич Пушкин родился в \\n1799 году. Его отец был крепостным крестьянином, а мать – крепостной крестьянкой. Детство и юность поэта прошли в селе Михайловском Пензенской губернии. В 1820-х годах семья переехала'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 13
        }
      ],
      "source": [
        "generated[0]"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "WCfz5Cs5ENOo"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Generate_text_with_RuGPTs_HF",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}